
"""
按照代码编写的顺序，依次为：

1.载入词表，构建两个映射关系。
 word_to_id 将单词转化为id表示
 id_to_word 将id转化为单词
2.将token文件的形式变为：
dict {'1234.jpg': [[4, 556, 44, 6, 57], [2223, 4, 54, 221]]}
即：一个字典，key 是图像名称，value 是一个列表，里面储存的是该图像的每一条描述，每条描述已转换为词 id 列表。
3.载入图像特征
构建batch
为每一张图像，随机挑选出来一条描述
4.计算图构建
tf.placeholder(dtype, shape=None, name=None)：占位符，用于传入外部数据。
四个 placeholder，分别是：图像特征、对应描述（词 id）、mask（代码中有介绍）、dropout 的 keep_prob 值。
文本embedding、图像embedding
进入lstm结构，全连接
"""
# -*- coding:utf-8 -*-

import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import pickle
import numpy as np
import math
import random

# Print TensorFlow log messages at INFO level and above.
logging.set_verbosity(logging.INFO)

# Input data locations, plus the directory for checkpoints and summaries.
input_description_file = "./data/results_20130124.token"
input_img_feature_dir = './data/download_inception_v3_features'
input_vocab_file = './data/vocab.txt'
output_dir = './data/local_run'

# Create the output directory on the first run.
if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)


def get_default_params():
    """Return the default model / training hyper-parameters as an HParams.

    num_lstm_nodes holds one hidden size per LSTM layer. It is indexed for
    each of the num_lstm_layers layers, so keep the two consistent.
    log_frequent and save_frequent are step intervals for logging and
    checkpointing.
    """
    defaults = {
        'num_vocab_word_threshold': 3,   # words seen fewer times are dropped
        'num_embedding_nodes': 32,       # word / image embedding size
        'num_timesteps': 10,             # captions are truncated / padded to this
        'num_lstm_nodes': [64, 64],
        'num_lstm_layers': 2,
        'num_fc_nodes': 32,
        'batch_size': 100,
        'cell_type': 'lstm',             # 'lstm' or 'gru'
        'clip_lstm_grads': 1.0,          # global-norm gradient clipping threshold
        'learning_rate': 0.001,
        'keep_prob': 0.8,                # dropout keep probability while training
        'log_frequent': 500,
        'save_frequent': 5000,
    }
    return tf.contrib.training.HParams(**defaults)

hps = get_default_params()
training_steps = 10 ** 6  # total number of training iterations


class Vocab(object):
    '''
    Vocabulary: two-way mapping between words and integer ids.

    Built from a "word<TAB>count" file. Words whose count is below the
    threshold are skipped, except '<UNK>', which is always kept. Ids are
    assigned in file order to the words that are kept.
    '''
    def __init__(self, filename, word_num_threshold):
        self._id_to_word = {}  # id -> word
        self._word_to_id = {}  # word -> id
        # Ids of '<UNK>' and '.'; they stay -1 if the file lacks those words.
        self._unk = -1
        self._eos = -1
        self._word_num_threshold = word_num_threshold
        self._read_dict(filename)

    def _read_dict(self, filename):
        '''
        Fill both mappings from the vocabulary file.

        :param filename: path of the "word<TAB>count" vocabulary file
        :raises Exception: if a kept word appears more than once
        '''
        with gfile.GFile(filename, 'r') as vocab_file:
            raw_lines = vocab_file.readlines()
        for raw_line in raw_lines:
            word, count_text = raw_line.strip('\r\n').split('\t')
            count = int(count_text)
            is_unk = (word == '<UNK>')
            # Rare words are dropped; '<UNK>' survives regardless of its count.
            if not is_unk and count < self._word_num_threshold:
                continue
            new_id = len(self._id_to_word)  # next free id, in file order
            if is_unk:
                self._unk = new_id
            elif word == '.':
                # '.' is exposed as the end-of-sentence id (eos).
                self._eos = new_id
            if new_id in self._id_to_word or word in self._word_to_id:
                raise Exception('duplicate words in vocab file')
            self._word_to_id[word] = new_id
            self._id_to_word[new_id] = word

    @property
    def unk(self):
        '''Id of '<UNK>' (-1 if it was not in the vocabulary file).'''
        return self._unk

    @property
    def eos(self):
        '''Id of '.' (-1 if it was not in the vocabulary file).'''
        return self._eos

    def word_to_id(self, word):
        '''Map one word to its id; words not in the vocabulary map to the '<UNK>' id.'''
        return self._word_to_id.get(word, self.unk)

    def id_to_word(self, cur_id):
        '''Map one id back to its word; unknown ids map to '<UNK>'.'''
        return self._id_to_word.get(cur_id, '<UNK>')

    def size(self):
        '''Number of words kept in the vocabulary.'''
        return len(self._word_to_id)

    def encode(self, sentence):
        '''Split a caption on single spaces and return its list of word ids.'''
        return list(map(self.word_to_id, sentence.split(' ')))

    def decode(self, sentence_id):
        '''Turn an iterable of word ids into a space-separated sentence.'''
        return ' '.join(map(self.id_to_word, sentence_id))


def parse_token_file(token_file):
    '''
    Parse the caption token file into image name -> list of captions.

    Each non-blank line is expected to look like
    "<image name>#<caption index>\t<caption text>". Captions keep the order
    in which they appear in the file.

    :param token_file: path of the token file
    :return: dict such as {'1234.jpg': ['this is a people', 'the people is happy']}
    :raises ValueError: if a non-blank line lacks the tab or the '#' separator
    '''
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip('\r\n')
        # Skip blank lines (e.g. a trailing empty line) instead of failing on unpack.
        if not line.strip():
            continue
        # maxsplit=1: a tab inside the caption text stays part of the caption.
        img_id, description = line.split('\t', 1)
        # rsplit: only the last '#' separates the caption index from the image name.
        img_name, _ = img_id.rsplit('#', 1)
        img_name_to_tokens.setdefault(img_name, []).append(description)
    return img_name_to_tokens


def convert_token_to_id(img_name_to_tokens, vocab):
    '''
    Encode every caption of every image into a list of word ids.

    :param img_name_to_tokens: dict image name -> list of caption strings
        (the output of parse_token_file)
    :param vocab: object whose encode(sentence) returns a list of ids (a Vocab)
    :return: dict with the same keys in the same order; each value is a list
        of id lists, e.g. {'1234.jpg': [[4, 556, 44, 6757], [2223, 4354, 22, 1]]}
    '''
    return {
        img_name: [vocab.encode(description) for description in descriptions]
        for img_name, descriptions in img_name_to_tokens.items()
    }


vocab = Vocab(input_vocab_file, hps.num_vocab_word_threshold)
vocab_size = vocab.size()  # number of words kept in the vocabulary
logging.info("vocab_size: %d", vocab_size)

# image name -> list of raw caption strings
img_name_to_tokens = parse_token_file(input_description_file)
# image name -> list of captions, each encoded as a list of word ids
img_name_to_token_ids = convert_token_to_id(img_name_to_tokens, vocab)


class ImageCaptionData(object):
    '''
    Data provider that serves batches of
    (image features, caption ids, caption mask, image names).
    '''
    def __init__(self,
                 img_name_to_token_ids,
                 img_feature_dir,
                 num_timesteps,
                 vocab,
                 deterministic=False):
        '''
        :param img_name_to_token_ids: dict image name -> list of caption id lists
        :param img_feature_dir: directory holding the pickled image-feature files
        :param num_timesteps: fixed caption length; captions are truncated or
            padded (with vocab.eos) to this length
        :param vocab: vocabulary; its `eos` id is used as the padding token
        :param deterministic: if True, never shuffle; if False (default),
            shuffle once after loading and again at every epoch boundary
        '''
        self._vocab = vocab
        # Full path of every file in the image-feature directory.
        self._all_img_feature_filepaths = [
            os.path.join(img_feature_dir, filename)
            for filename in gfile.ListDirectory(img_feature_dir)]

        self._img_name_to_token_ids = img_name_to_token_ids
        self._num_timesteps = num_timesteps
        self._indicator = 0  # start index of the next batch
        self._deterministic = deterministic
        self._img_feature_filenames = []  # image names, row-aligned with the features
        self._img_feature_data = []  # image features
        self._load_img_feature_pickle()
        # Shuffle exactly once after loading. The loader no longer shuffles,
        # so the large feature matrix is not permuted twice.
        if not self._deterministic:
            self._random_shuffle()

    def _load_img_feature_pickle(self):
        '''
        Read every pickled feature file and merge the contents into one array.

        Each file must unpickle to a (filenames, features) pair whose entries
        line up row by row. Afterwards _img_feature_data is 2-D
        [num_images, feature_dim] and _img_feature_filenames is the 1-D array
        of matching image names.

        :raises ValueError: if no feature file is found, or if the number of
            names does not match the number of feature rows
        '''
        feature_chunks = []
        for filepath in self._all_img_feature_filepaths:
            with gfile.GFile(filepath, 'rb') as f:
                # NOTE: pickle can execute arbitrary code while loading; only
                # point this at trusted, locally generated feature files.
                # The latin-1 encoding presumably allows reading numpy pickles
                # written by Python 2.
                filenames, features = pickle.load(f, encoding='iso-8859-1')
            self._img_feature_filenames += filenames  # concatenate the name lists
            feature_chunks.append(features)
        if not feature_chunks:
            raise ValueError('no image feature files found in the feature directory')
        # e.g. [(1000, 1, 1, 2048), (1000, 1, 1, 2048)] -> (2000, 1, 1, 2048)
        merged = np.vstack(feature_chunks)
        # Flatten every non-sample axis: (31783, 1, 1, 2048) -> (31783, 2048).
        feature_dim = int(np.prod(merged.shape[1:]))
        self._img_feature_data = np.reshape(merged, (merged.shape[0], feature_dim))
        self._img_feature_filenames = np.asarray(self._img_feature_filenames)
        if len(self._img_feature_filenames) != self._img_feature_data.shape[0]:
            # Misaligned names would silently pair images with the wrong captions.
            raise ValueError('got %d image names but %d feature rows'
                             % (len(self._img_feature_filenames),
                                self._img_feature_data.shape[0]))
        print(self._img_feature_data.shape)  # e.g. (31783, 2048)
        print(self._img_feature_filenames.shape)  # e.g. (31783,)

    def size(self):
        # Number of images.
        return len(self._img_feature_filenames)

    def img_feature_size(self):
        # Dimension of one image feature vector.
        return self._img_feature_data.shape[1]

    def _random_shuffle(self):
        # Apply one permutation to names and features so they stay aligned.
        p = np.random.permutation(self.size())
        self._img_feature_filenames = self._img_feature_filenames[p]
        self._img_feature_data = self._img_feature_data[p]

    def _img_desc(self, filenames):
        '''
        Pick one random caption per image and pad or truncate it to num_timesteps.

        :param filenames: image names of the batch
        :return: (sentence_ids, weights), both int arrays of shape
            [len(filenames), num_timesteps]. weights is 1 for real words and
            0 for eos padding, so the loss can ignore padded positions.
        '''
        batch_sentence_ids = []
        batch_weights = []
        for filename in filenames:
            token_ids_set = self._img_name_to_token_ids[filename]
            # Copy the chosen caption. The list lives inside the shared
            # caption dict, and padding it in place (`+=`) used to grow it
            # permanently, so later picks got an all-ones mask.
            chosen_token_ids = list(random.choice(token_ids_set))
            valid_length = min(len(chosen_token_ids), self._num_timesteps)
            remaining_length = self._num_timesteps - valid_length
            sentence_ids = (chosen_token_ids[:valid_length]
                            + [self._vocab.eos] * remaining_length)
            weight = [1] * valid_length + [0] * remaining_length
            batch_sentence_ids.append(sentence_ids)
            batch_weights.append(weight)
        return np.asarray(batch_sentence_ids), np.asarray(batch_weights)

    def next(self, batch_size):
        '''
        Return the next batch of batch_size samples.

        The method takes the next slice of image names and features, then
        picks one caption per image. If fewer than batch_size samples remain
        in the current epoch, they are skipped: the data is reshuffled
        (unless deterministic) and reading restarts from the beginning.

        :param batch_size: number of samples per batch
        :return: (img_features [batch_size, feature_dim],
                  sentence_ids [batch_size, num_timesteps],
                  weights [batch_size, num_timesteps],
                  img_names [batch_size])
        :raises ValueError: if batch_size is larger than the number of images
        '''
        if batch_size > self.size():
            # An explicit check, unlike `assert`, survives `python -O`.
            raise ValueError('batch_size %d exceeds the number of images %d'
                             % (batch_size, self.size()))
        end_indicator = self._indicator + batch_size
        if end_indicator > self.size():
            if not self._deterministic:
                self._random_shuffle()
            self._indicator = 0
            end_indicator = batch_size

        batch_img_features = self._img_feature_data[self._indicator:end_indicator]
        batch_img_names = self._img_feature_filenames[self._indicator:end_indicator]

        # batch_weights is a mask aligned with batch_sentence_ids, e.g.
        # [100, 101, 102, eos, eos, eos] -> [1, 1, 1, 0, 0, 0]; positions with
        # weight 0 are excluded from the loss and accuracy.
        batch_sentence_ids, batch_weights = self._img_desc(batch_img_names)

        self._indicator = end_indicator
        return batch_img_features, batch_sentence_ids, batch_weights, batch_img_names


caption_data = ImageCaptionData(img_name_to_token_ids=img_name_to_token_ids,
                                img_feature_dir=input_img_feature_dir,
                                num_timesteps=hps.num_timesteps,
                                vocab=vocab)
img_feature_dim = caption_data.img_feature_size()  # length of one image feature vector

def create_rnn_cell(hidden_dim, cell_type):
    '''
    Build a single RNN cell of the requested type.

    :param hidden_dim: number of hidden units of the cell
    :param cell_type: 'lstm' (BasicLSTMCell with tuple state) or 'gru'
    :return: the constructed cell
    :raises ValueError: if cell_type is not 'lstm' or 'gru'
    '''
    if cell_type == 'lstm':
        return tf.contrib.rnn.BasicLSTMCell(hidden_dim, state_is_tuple=True)
    if cell_type == 'gru':
        return tf.contrib.rnn.GRUCell(hidden_dim)
    # ValueError subclasses Exception, so callers catching Exception still work.
    raise ValueError("%s has not been supported" % cell_type)


def dropout(cell, keep_prob):
    '''Wrap `cell` so that its outputs are dropped out with keep probability `keep_prob`.'''
    wrapper_cls = tf.contrib.rnn.DropoutWrapper
    return wrapper_cls(cell=cell, output_keep_prob=keep_prob)


def get_train_model(hps, vocab_size, img_feature_dim):
    """Build the TF1 training graph for the image-caption model.

    The projected image feature is the first LSTM input. It is followed by
    the embeddings of the first num_timesteps - 1 caption words, so the
    output at step t is trained to predict word t of `sentence` (teacher
    forcing). Loss and accuracy are averaged over positions where `mask` is 1.

    :param hps: hyper-parameters (see get_default_params). num_lstm_nodes is
        indexed for each of the num_lstm_layers layers. Its last entry is
        also used as the LSTM output size when reshaping, so the two should
        agree.
    :param vocab_size: number of words in the vocabulary (= output classes).
    :param img_feature_dim: length of each image feature vector.
    :return: ((img_feature, sentence, mask, keep_prob),  # placeholders to feed
              (loss, accuracy, train_op),
              global_step)
    """
    num_timesteps = hps.num_timesteps
    batch_size = hps.batch_size

    # The placeholders have a fixed batch dimension, so every fed batch must
    # have exactly hps.batch_size rows.
    img_feature = tf.placeholder(tf.float32, (batch_size, img_feature_dim))
    sentence = tf.placeholder(tf.int32, (batch_size, num_timesteps))  # caption word ids
    mask = tf.placeholder(tf.float32, (batch_size, num_timesteps))  # per-position loss weight
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')  # dropout keep probability

    global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step', trainable=False)

    # Training scheme (the string below is kept as-is from the original):
    #   caption:      [a, b, c, d, e, f]
    #   actual input: [img, a, b, c, d, e]
    #   step 1: img_feature -> image embedding -> predicts a
    #   step 2: a -> word embedding -> lstm -> predicts b
    #   step 3: b -> ...                    -> predicts c
    '''
        训练过程：
        句子：[a, b, c, d, e, f]

        真正的输入：[img, a, b, c, d, e]
        图像特征 [0.3, 0.5, 0.2, 0.9]
        predict #1 img_feature -> embedding_img -> (a)
        predict #2 a -> embedding_word -> lstm -> b
        predict #3 b ->                        -> c  
    '''
    # Sets up the embedding layer.
    embedding_initializer = tf.random_uniform_initializer(-1.0, 1.0)
    # tf.random_uniform_initializer() builds an initializer that samples from a
    # uniform distribution (here between -1.0 and 1.0).
    # Reference: https://www.w3cschool.cn/tensorflow_python/tensorflow_python-f1np2gyt.html
    with tf.variable_scope('embedding', initializer=embedding_initializer):
        embeddings = tf.get_variable(
            'embeddings',
            [vocab_size, hps.num_embedding_nodes],
            tf.float32)
        # The last caption word is only ever a target, so it is not embedded as an input.
        embed_token_ids = tf.nn.embedding_lookup(embeddings, sentence[:, 0:num_timesteps - 1])
        # embed_token_ids shape: [batch_size, num_timesteps - 1, num_embedding_nodes]

    # Image embedding: a fully-connected layer projects the image feature
    # vector (e.g. 2048-d) to the word-embedding size. The image embedding
    # can then be concatenated with the word embeddings as the first input step.
    img_feature_embed_init = tf.uniform_unit_scaling_initializer(factor=1.0)
    # Reference: https://www.w3cschool.cn/tensorflow_python/tensorflow_python-fy6t2o0o.html
    with tf.variable_scope('image_feature_embed', initializer=img_feature_embed_init):
        # img_feature:[batch_size, img_feature_dim]
        # embed_img: [batch_size, num_embedding_nodes]
        # NOTE(review): no name= is given, so TF auto-generates this layer's
        # variable names. Presumably reordering layer creation would rename
        # them and break existing checkpoints.
        embed_img = tf.layers.dense(img_feature, hps.num_embedding_nodes)
        embed_img = tf.expand_dims(embed_img, 1)  # -> [batch_size, 1, num_embedding_nodes]
        # embed_inputs shape: [batch_size, num_timesteps, num_embedding_nodes]
        embed_inputs = tf.concat([embed_img, embed_token_ids], axis=1)

    # Sets up LSTM network.
    # The uniform init range shrinks with the combined embedding + hidden size.
    scale = 1.0 / math.sqrt(hps.num_embedding_nodes + hps.num_lstm_nodes[-1]) / 3.0
    lstm_init = tf.random_uniform_initializer(-scale, scale)
    with tf.variable_scope('lstm_nn', initializer=lstm_init):
        cells = []
        for i in range(hps.num_lstm_layers):
            # num_lstm_nodes needs at least num_lstm_layers entries.
            cell = create_rnn_cell(hps.num_lstm_nodes[i], hps.cell_type)
            cell = dropout(cell, keep_prob)
            cells.append(cell)
        cell = tf.contrib.rnn.MultiRNNCell(cells)

        initial_state = cell.zero_state(hps.batch_size, tf.float32)
        # rnn_outputs: [batch_size, num_timesteps, hps.num_lstm_nodes[-1]]
        rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                           embed_inputs,
                                           initial_state=initial_state)

    # Sets up the fully-connected layer.
    fc_init = tf.uniform_unit_scaling_initializer(factor=1.0)
    with tf.variable_scope('fc', initializer=fc_init):
        # Merge the batch and time axes (keep the last one) so that one dense
        # layer is applied at every time step:
        # [batch_size * num_timesteps, num_lstm_nodes[-1]].
        rnn_outputs_2d = tf.reshape(rnn_outputs, [-1, hps.num_lstm_nodes[-1]])
        fc1 = tf.layers.dense(rnn_outputs_2d, hps.num_fc_nodes, name='fc1')
        fc1_dropout = tf.nn.dropout(fc1, keep_prob)
        fc1_dropout = tf.nn.relu(fc1_dropout)
        logits = tf.layers.dense(fc1_dropout, vocab_size, name='logits')
        # logits: unnormalized scores over the whole vocabulary, shape
        # [batch_size * num_timesteps, vocab_size] (1000 x vocab_size with the defaults).
        # Note: this dropout is tf.nn.dropout applied to a tensor, unlike the
        # LSTM dropout, which wraps the cells with tf.contrib.rnn.DropoutWrapper.


    with tf.variable_scope('loss'):
        # The logits were flattened row-major (image 1 steps 1..T, then image 2,
        # ...) before the fully-connected layer. The labels and mask are
        # flattened the same way so each prediction lines up with its
        # ground truth (see the string below).
        '''
        这里多做一点注释，以防以后忘掉
        因为在 进行 全连接之前，已经将数据reshape 成了二维，
        即 [
                [1.jpg的第1个timestep, lstm最后一层的个数],
                [1.jpg的第2个timestep, lstm最后一层的个数],
                ...
                [2.jpg的第1个timestep, lstm最后一层的个数],
                [2.jpg的第2个timestep, lstm最后一层的个数]
            ]
        这样，最终logits输出的是
            [1.jpg的第1个timestep预测值的概率分布，
             1.jpg的第2个timestep预测值的概率分布，
            ...
             2.jpg的第1个timestep预测值的概率分布，
            ]
        同样的， 将sentences进行reshape 之后，就成了
            [
                1.jpg的第1个timestep gt
                1.jpg的第2个timestep gt
                ...
                2.jpg的第1个timestep gt
                2.jpg的第2个timestep gt
            ]
        这样，正好可以 将 预测值 和 真实值 对上
        '''
        sentence_flatten = tf.reshape(sentence, [-1])
        mask_flatten = tf.reshape(mask, [-1])
        mask_sum = tf.reduce_sum(mask_flatten)  # total weight = number of counted positions
        # sparse_softmax_cross_entropy_with_logits applies softmax to the logits,
        # treats the labels as class indices (equivalent to one-hot targets),
        # and computes the cross-entropy per position.
        softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=sentence_flatten)
        # Zero out the loss at positions whose mask is 0.
        weighted_softmax_loss = tf.multiply(softmax_loss,
                                            tf.cast(mask_flatten, tf.float32))

        prediction = tf.argmax(logits, 1)  # most likely word id at each position
        # Compare the predictions with the ground truth.
        correct_prediction = tf.equal(tf.cast(prediction,tf.int32), sentence_flatten)
        # Use the mask so that masked-out positions do not count.
        correct_prediction_with_mask = tf.multiply(
            tf.cast(correct_prediction, tf.float32),
            mask_flatten)
        # Both metrics are mask-weighted averages.
        accuracy = tf.reduce_sum(correct_prediction_with_mask) / mask_sum
        loss = tf.reduce_sum(weighted_softmax_loss) / mask_sum
        tf.summary.scalar('loss', loss)

    with tf.variable_scope('train_op'):
        tvars = tf.trainable_variables()
        for var in tvars:
            logging.info("variable name: %s" % (var.name))
        # Clips ALL trainable-variable gradients by their global norm, not
        # only the LSTM's, despite the hyper-parameter's name.
        grads, _ = tf.clip_by_global_norm(  # clip gradients
            tf.gradients(loss, tvars), hps.clip_lstm_grads)
        for grad, var in zip(grads, tvars):
            # NOTE(review): var.name ends in ':0', which is not a legal summary
            # tag character; TF presumably sanitizes it with a warning.
            tf.summary.histogram('%s_grad' % (var.name), grad)
        optimizer = tf.train.AdamOptimizer(hps.learning_rate)
        train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

    return ((img_feature, sentence, mask, keep_prob),
            (loss, accuracy, train_op),
            global_step)


# Build the training graph and unpack its handles.
placeholders, metrics, global_step = get_train_model(hps, vocab_size, img_feature_dim)
(img_feature, sentence, mask, keep_prob) = placeholders
(loss, accuracy, train_op) = metrics

# Merge every summary registered while building the graph above.
summary_op = tf.summary.merge_all()

init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=10)  # keep at most 10 recent checkpoints

with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(output_dir, sess.graph)
    # Ops run on every step; the merged summary is appended only on logging steps.
    base_fetches = [global_step, loss, accuracy, train_op]
    try:
        for i in range(training_steps):
            batch_img_features, batch_sentence_ids, batch_weights, _ = caption_data.next(hps.batch_size)
            # Same order as placeholders: (img_feature, sentence, mask, keep_prob).
            input_vals = (batch_img_features, batch_sentence_ids, batch_weights, hps.keep_prob)
            feed_dict = dict(zip(placeholders, input_vals))

            should_log = (i + 1) % hps.log_frequent == 0
            should_save = (i + 1) % hps.save_frequent == 0
            fetches = (base_fetches + [summary_op]) if should_log else base_fetches
            outputs = sess.run(fetches, feed_dict)
            global_step_val, loss_val, accuracy_val = outputs[0:3]
            if should_log:
                # The summary is the first fetch after the base ones.
                summary_str = outputs[len(base_fetches)]
                writer.add_summary(summary_str, global_step_val)
                logging.info('Step: %5d, loss: %3.3f, accuracy: %3.3f'
                             % (global_step_val, loss_val, accuracy_val))
            if should_save:
                logging.info("Step: %d, image caption model saved" % (global_step_val))
                saver.save(sess, os.path.join(output_dir, "image_caption"), global_step=global_step_val)
    finally:
        # Flush buffered summary events to disk and release the event file,
        # including when training stops early (exception / Ctrl-C).
        writer.close()

